/*  Copyright 2009 Marc Toussaint
    email: mtoussai@cs.tu-berlin.de

    This program is free software: you can redistribute it and/or modify
    it under the terms of the GNU General Public License as published by
    the Free Software Foundation, either version 3 of the License, or
    (at your option) any later version.

    This program is distributed in the hope that it will be useful,
    but WITHOUT ANY WARRANTY; without even the implied warranty of
    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    GNU General Public License for more details.

    You should have received a COPYING file of the GNU General Public License
    along with this program. If not, see <http://www.gnu.org/licenses/> */

#ifndef MT_algos_h
#define MT_algos_h

#include "array.h"
#include "util.h"

namespace MT{

  //! Normalizes the data array X in place (inferred from the name and the
  //! non-const reference; the exact normalization is not visible here —
  //! TODO confirm in algos.cpp).
  void normalizeData(arr& X);

  //----- spline stuff
  // NOTE(review): implementations are not visible here. Presumably X receives
  // the sampled spline through control points P, with `intersteps` samples
  // between consecutive control points; the second overload presumably also
  // fills V (velocities). randomSpline presumably draws `points` random
  // control points of dimension `dim` in [lo,hi]; the meaning of `cycles`
  // and of dX (presumably the derivative) should be confirmed in algos.cpp.
  void makeSpline(arr& X,arr& P,uint intersteps);
  void makeSpline(arr& X,arr& V,arr& P,uint intersteps);
  void randomSpline(arr& X,uint dim,uint points,uint intersteps=100,double lo=-1.,double hi=1.,uint cycles=1);
  void randomSpline(arr& X,arr& dX,uint dim,uint points,uint intersteps=100,double lo=-1.,double hi=1.,uint cycles=1);

  //----- gradient optimization
  //! Checks the analytic derivative df against f at point x with the given
  //! tolerance (presumably via finite differences — TODO confirm). This
  //! overload takes a vector-valued f that writes its result into its first
  //! argument; `data` is presumably forwarded to f and df unchanged.
  void checkGradient(void (*f)(arr&,const arr&,void*),
		     void (*df)(arr&,const arr&,void*),
		     void *data,
		     const arr& x,double tolerance);
  //! Same check for a scalar-valued f (f returns a double).
  void checkGradient(double (*f)(const arr&,void*),
		     void (*df)(arr&,const arr&,void*),
		     void *data,
		     const arr& x,double tolerance);
  //! Minimizes the scalar function f (with gradient df) starting at x; x is
  //! passed by non-const reference and presumably receives the result.
  //! NOTE(review): presumably fmin_return (if non-null) receives the final
  //! value, `method` selects the algorithm and chkGrad enables a gradient
  //! check first; the method codes and the meaning of the int return value
  //! are not visible here — TODO confirm.
  int minimize(double (*f)(const arr&,void*),
               void (*df)(arr&,const arr&,void*),
               void *data,
               arr& x,
               double *fmin_return,
               int method,
               uint maxIterations,
               double stoppingTolerance,
               bool chkGrad);

  //----- LU decomposition (per the function names)
  //! determinant of X
  double determinant_LU(const arr& X);
  //! writes the inverse of X into Xinv
  void inverse_LU(arr& Xinv,const arr& X);
  //! factors X into L and U (presumably X = L*U; pivoting not visible here)
  void LU_decomposition(arr& L,arr& U,const arr& X);

  //----- Runge-Kutta
  //! standard Runge-Kutta 4: one step of size dt from x0 to x1, where df
  //! writes the time derivative of x into xd
  void rk4(arr& x1,const arr& x0,
	   void (*df)(arr& xd,const arr& x),
	   double dt);
  //! same for a second-order differential equation: ddf writes the
  //! acceleration xdd given position x and velocity v
  void rk4dd(arr& x1,arr& v1,const arr& x0,const arr& v0,
    void (*ddf)(arr& xdd,const arr& x,const arr& v),
    double dt);
  /*! RK with discrete event localization (zero-crossing detection):
      the function sf computes some double-valued indicators. If one of
      these indicators crosses zero, this is interpreted as a discontinuity
      in the dynamics. The algorithm iteratively tries to find the
      zero-crossing point up to a tolerance tol (measured in time). The
      routine returns false if no switch occurred; otherwise it returns
      true and writes the executed time step into dt (hence dt is passed
      by reference). */
  bool rk4_switch(arr& x1,arr& s1,const arr& x0,const arr& s0,
    void (*df)(arr& xd,const arr& x),
    void (*sf)(arr& s,const arr& x),
    double& dt,double tol);
  //! same for 2nd order DEs
  bool rk4dd_switch(arr& x1,arr& v1,arr& s1,const arr& x0,const arr& v0,const arr& s0,
    void (*ddf)(arr& xdd,const arr& x,const arr& v),
    void (*sf)(arr& s,const arr& x,const arr& v),
    double& dt,double tol);

  //----- bandpass filtering
  //! convolves x with the kernel function h into y; `scale` presumably
  //! scales the kernel argument — TODO confirm kernel convention
  void convolution(arr &y,const arr &x,double (*h)(double),double scale=1.);
  //! band-pass filter of x into y between the two wavelengths
  void bandpassFilter(arr &y,const arr &x,double loWavelength,double hiWavelength);
  //! presumably the energy of the band-passed signal — TODO confirm
  void bandpassEnergy(arr &y,const arr &x,double loWavelength,double hiWavelength);

  //----- comparing connectivity matrices
  // NOTE(review): p is presumably a permutation applied to one matrix before
  // comparison; the meaning of `sub` is not visible here. matannealing
  // presumably searches for p by simulated annealing (per its name and
  // annealing* parameters) and returns the resulting distance.
  double matdistance(intA& fix,intA& fox,uintA& p,bool sub);
  double matdistance(intA& A,intA& B,bool sub);
  double matannealing(intA& fix,intA& fox,uintA& p,bool sub,double annealingRepetitions,double annealingCooling);
}


//===========================================================================
//
// MonSolver
//

/*! a trivial solver for monotonic unimodal 1D functions
    NOTE(review): the implementation is not visible here. From the interface,
    init() prepares the search for the parameter `par` (with an initial
    width `wide`), and solve() is presumably called repeatedly with the
    current error `err`, updating `par` in place — TODO confirm. */
class MonSolver{
public:
  double min,max; //!< presumably the current search bracket for par
  int phase;      //!< presumably the internal search stage
  MonSolver();
  void init(double& par,double wide=2.);
  void solve(double& par,const double& err);
};


//===========================================================================
//
// Linear Statistics
//

/*! collects means and (co-)variances of data, also input-output data.
    This class is the basis for linear regression techniques, like
    Least Squares, Partial Least Squares, but can also be used on its
    own, simply to average or calculate (co-)variances \ingroup
    regression */
class LinearStatistics{
public:
  double accum;  //!< the accumulation norm (=number of collected data points if not weighted)
  arr meanX,meanY,varX,covXY; //!< these are not normalized or centered buffers
  arr MeanX; //!< X mean
  arr MeanY; //!< Y mean
  arr VarX;  //!< X variance
  arr CovXY; //!< XY covariance
  double lambda; //!< forgetting rate [default=0]
  bool computed; //!< internal indicator whether recomputation is needed

  LinearStatistics();

  //feed data
  //! add an input-output sample (x,y) with the given weight
  void learn(const arr& x,const arr& y,double weight=1.);
  //! same with a scalar output y.
  //! NOTE: do not give `weight` a default here — learn(x,1.) would then be
  //! ambiguous with learn(const arr&,double) below.
  void learn(const arr& x,double y,double weight);
  //! add an input-only sample x (weight presumably 1 — TODO confirm)
  void learn(const arr& x);
  //! add an input-only sample x with the given weight
  void learn(const arr& x,double weight);
  //! presumably resets all accumulated statistics
  void clear();
  //! presumably decays the accumulated statistics by the given rate
  void forget(double lambda=1.);

  //get information
  uint inDim();
  double variance();
  void correlationX(arr& corr,bool clearDiag=false);
  void correlationXY(arr& corr,bool clearDiag=false);
  void mahalanobisMetric(arr& metric);
  void regressor(arr& A);
  void regressor(arr& A,arr& a);
  void predict(const arr& x,arr& y);

  //used internally (presumably fill MeanX/MeanY/VarX/CovXY from the buffers)
  void compute();
  void computeZeroMean();

  //output
  void write(std::ostream& os) const;
};
// stdOutPipe is a util.h macro; presumably it defines operator<< via write()
stdOutPipe(LinearStatistics);




//===========================================================================
//
// Tuple index
//

/*! NOTE(review): the implementation is not visible here. From the interface:
    init(k,n) fills the index table (this object is itself a uintA) and `tri`;
    index(i) maps a tuple i to a single uint index; checkValid() performs a
    consistency check. The exact tuple encoding should be confirmed in
    algos.cpp. */
class TupleIndex:public uintA{
public:
  uintA tri;
  void init(uint k,uint n);
  uint index(uintA i); // NOTE(review): i is taken by value (a copy); switching to const& must be done together with the definition
  void checkValid();
};


//===========================================================================
//
// Kalman filter
//

/*! a Kalman filter
    Model, per the member comments below:
      x(t) = A*x(t-1) [+ a, presumably an offset — TODO confirm] + N(0,Q)
      y    = C*x + N(0,R)
    The implementations are not visible here; the method notes below are
    inferred from names and signatures and should be confirmed in algos.cpp. */
class Kalman{
public:
  // NOTE: `public:` is required — with class-key `class` every member would
  // default to private and none of the methods below could be called.
  arr
    A,a, //linear forward transition
    Q, //covariance of forward transition x(t) = A*x(t-1) + \NN(0,Q)
    C, //linear observation matrix
    R; //covariance of observation: y = C*x + \NN(0,R)

  //! presumably sets up the model for state dimension d with transition
  //! noise variance varT and observation noise variance varO
  void setTransitions(uint d,double varT,double varO);
  //! filters the observations Y; X and V presumably receive the filtered means
  //! and (co)variances; Rt optionally supplies per-step observation noise
  void filter(arr& Y,arr& X,arr& V,arr *Rt=0);
  //! smoothing of Y; Vxx optionally receives cross-covariances (presumably)
  void smooth(arr& Y,arr& X,arr& V,arr *Vxx=0,arr *Rt=0);
  //! presumably re-estimates the model parameters from Y (EM step)
  void EMupdate(arr& Y,arr *Rt=0);
  //! forward-backward pass on y; the meaning of f,F,g,G,p,P is not visible
  //! here — TODO confirm
  void fb(arr& y,arr& f,arr& F,arr& g,arr& G,arr& p,arr& P,arr *Rt=0);
};


//===========================================================================
//
// X-Splines
//

class XSpline{
public:
  // NOTE(review): the original comment was cut off ("Distance between each");
  // presumably the parameter step between evaluation points — TODO confirm.
  double DELTA;

  arr V; //vertices
  arr W; //weights associated to each vertex

  XSpline();
  ~XSpline();

  //! presumably sets every vertex weight in W to w
  void setWeights(double w);
  //! configures the spline shape; semantics of `hit`/`smooth` not visible here
  void type(bool hit,double smooth);
  //! presumably makes V refer to the vertex array P (per the name)
  void referTo(arr& P);
  //! evaluates the spline at parameter t and returns the point
  arr eval(double t);
  //! evaluates at t into x; if v is non-null it presumably receives the velocity
  void eval(double t,arr& x,arr* v=0);
  //! same, writing the (presumed) velocity into v
  void eval(double t,arr& x,arr& v);
};


//===========================================================================
//
// Partial Least Squares (PLS, SIMPLS, de Jong)
//

/*! An implementation of Partial Least Squares following the SIMPLS algorithm (de Jong).
    This PLS implementation finds an optimal linear regression (from n to m dimensions)
    by first calculating input projections `with highest correlation to the output'. This
    implementation uses the Singular Value Decomposition routine MT::svd and
    the LinearStatistics. \ingroup regression */
class PartialLeastSquares{
public:
  LinearStatistics S; //!< the statistics collected with learning

  arr W,Q,B; //the transformation matrices

  arr resErr; //the residual errors

  //! maximal number of projections; constructed with the name
  //! "PLSmaxProjections" and default 0 (presumably a named configuration
  //! parameter from util.h; the meaning of 0 is not visible here)
  MT::Parameter<uint> maxProj;

  PartialLeastSquares():maxProj("PLSmaxProjections",0){ }

  //feed data (presumably forwarded to S)
  void learn(const arr& x,const arr& y,double weight=1.);
  void learn(const arr& x,double y,double weight=1.);
  void clear();

  //access: map an input x to the predicted output
  void map(const arr& x,arr& y);
  void map(const arr& x,double& y);
  double map(const arr& x);
  arr projection(uint k);
  uint inDim();
  uint outDim();

  void write(std::ostream& os) const;

  //internally used: runs the SIMPLS computation
  void SIMPLS();
};



//===========================================================================
//
// helpers to include foreign code
//

// NOTE(review): these names (vector / nrerror / free_vector) follow the
// Numerical Recipes in C utility conventions; presumably they are provided
// for the foreign sources included in the implementation section — TODO
// confirm. A global function named `vector` can clash with std::vector in
// translation units that use `using namespace std;`.
//! presumably allocates a double vector indexable over [i..j]
double *vector(uint i,uint j);
//! presumably reports the error message msg
void nrerror(const char* msg);
//! presumably releases a vector obtained from vector(i,j)
void free_vector(double* p,uint i,uint j);


//===========================================================================
//
// implementations
//

#ifdef MT_IMPLEMENTATION
#  include"algos.cpp"
#ifdef MT_algos_extern
// sources that are only compiled in when MT_algos_extern is defined
#  include"algos_LU.cpp"
#  include"algos_CG.cpp"
//#  include"algos_LM.c"
#endif
// algos_rk.cpp is needed in every configuration: include it exactly once,
// unconditionally. Listing it inside the MT_algos_extern branch as well would
// pull its definitions into this translation unit twice.
#  include"algos_rk.cpp"
#endif

#endif
